Skip to content

refactor: port CSE module to support JAX via xp parameter#447

Merged
Jammy2211 merged 1 commit into
mainfrom
feature/cse-jax-port
May 26, 2026
Merged

refactor: port CSE module to support JAX via xp parameter#447
Jammy2211 merged 1 commit into
mainfrom
feature/cse-jax-port

Conversation

@Jammy2211
Copy link
Copy Markdown
Collaborator

Summary

Port the CSE (Cored Steep Ellipsoid) module to support JAX by threading xp=np through all forward-path methods, mirroring the already JAX-ready MGE module. This unblocks JAX JIT compilation for any profile that uses CSE decomposition (primarily NFW). Phase 2 of the mass profiles refactoring epic (#445).

API Changes

Added xp=np keyword argument to four MassProfileCSE methods: convergence_cse_1d_from, deflections_via_cse_from, _convergence_2d_via_cse_from, _deflections_2d_via_cse_from. Existing callers passing no xp argument are unaffected (default is np). The decomposition solver (_decompose_convergence_via_cse_from) stays NumPy-only — it runs before JIT tracing. See full details below.

Test Plan

  • pytest test_autogalaxy/profiles/mass/ — 406 passed, 0 failed
  • profiles_jit.py JAX three-step test for CSE-based profiles (follow-up on autolens_workspace_test)
Full API Changes (for automation & release notes)

Changed Signature

  • MassProfileCSE.convergence_cse_1d_from(grid_radii, core_radius)convergence_cse_1d_from(grid_radii, core_radius, xp=np)
  • MassProfileCSE.deflections_via_cse_from(term1, term2, term3, term4, axis_ratio_squared, core_radius) → adds xp=np
  • MassProfileCSE._convergence_2d_via_cse_from(grid_radii, **kwargs) → adds xp=np
  • MassProfileCSE._deflections_2d_via_cse_from(grid, **kwargs) → adds xp=np

Changed Behaviour

  • deflections_via_cse_from now uses xp.sqrt and xp.vstack instead of hardcoded np.sqrt / np.vstack
  • NFW.deflections_2d_via_cse_from and NFW.convergence_2d_via_cse_from now thread xp=xp to CSE base methods

Migration

No migration needed — all new xp parameters default to np, preserving existing behaviour.

🤖 Generated with Claude Code

Thread xp=np through all MassProfileCSE methods (convergence_cse_1d_from,
deflections_via_cse_from, _convergence_2d_via_cse_from,
_deflections_2d_via_cse_from) and replace np.sqrt/np.vstack with xp
equivalents. Thread xp=xp through NFW CSE callers. The decomposition
solver (_decompose_convergence_via_cse_from) stays NumPy-only as it is
a one-time setup computation, not part of the JIT-traced forward pass.

Phase 2 of mass profiles refactoring epic (#445).

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
@Jammy2211 Jammy2211 added the pending-release PR queued for the next release build label May 26, 2026
@Jammy2211 Jammy2211 merged commit e6ab72c into main May 26, 2026
6 checks passed
@Jammy2211 Jammy2211 deleted the feature/cse-jax-port branch May 26, 2026 09:40
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

pending-release PR queued for the next release build

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant